# -*- coding: utf-8 -*-
'''
Created on 2017年4月8日

@author: ZhuJiahui
'''

import os
import time
import numpy as np
from file_utils.file_writer import quick_write_1d_to_text
from file_utils.file_reader import read_to_1d_list_gbk

def random_mixture(read_directory, write_directory,
                   selected_topic_index=(0, 1, 3, 4, 6, 7, 8, 9, 12, 14, 18, 19),
                   each_selected_num=4000, each_train_num=3000, seed=None):
    '''
    Build a randomly sampled train/test mixture from several per-topic files.

    For every topic index ``t`` in ``selected_topic_index``, the file
    ``<read_directory>/<t>.txt`` is read with ``read_to_1d_list_gbk``. Its
    entries are randomly permuted, and the first ``each_selected_num``
    permuted entries are kept. The first ``each_train_num`` of those go to the
    training set and the rest go to the test set. The two sets are written with
    ``quick_write_1d_to_text`` to ``train_mixture_text.txt`` and
    ``test_mixture_text.txt`` inside ``write_directory``.

    :param read_directory: directory containing the ``<topic>.txt`` files.
    :param write_directory: directory the two output files are written to.
    :param selected_topic_index: iterable of topic file indices to include.
    :param each_selected_num: number of entries sampled from each topic file.
    :param each_train_num: how many of the sampled entries per topic go to the
        training set; the remaining ``each_selected_num - each_train_num``
        go to the test set.
    :param seed: optional seed for a private ``numpy.random.RandomState``.
        If ``None``, the global NumPy RNG is used (original behavior).
    :raises ValueError: if ``each_train_num`` is not within
        ``[0, each_selected_num]``, or if a topic file has fewer than
        ``each_selected_num`` entries.
    '''
    if not 0 <= each_train_num <= each_selected_num:
        raise ValueError(
            "each_train_num (%d) must be between 0 and each_selected_num (%d)"
            % (each_train_num, each_selected_num))

    # None keeps using the module-level RNG, so default calls behave as before.
    rng = np.random if seed is None else np.random.RandomState(seed)

    train_text = []
    test_text = []
    for each in selected_topic_index:
        file_path = os.path.join(read_directory, str(each) + ".txt")
        this_file_text = read_to_1d_list_gbk(file_path)

        # Fail loudly with context instead of an opaque IndexError mid-loop.
        if len(this_file_text) < each_selected_num:
            raise ValueError(
                "%s has %d entries, but %d are required"
                % (file_path, len(this_file_text), each_selected_num))

        shuffle_list = np.arange(len(this_file_text))
        rng.shuffle(shuffle_list)

        train_text.extend(this_file_text[k] for k in shuffle_list[:each_train_num])
        test_text.extend(
            this_file_text[k]
            for k in shuffle_list[each_train_num:each_selected_num])

    quick_write_1d_to_text(
        os.path.join(write_directory, "train_mixture_text.txt"), train_text)
    quick_write_1d_to_text(
        os.path.join(write_directory, "test_mixture_text.txt"), test_text)

if __name__ == '__main__':
    # time.clock() was removed in Python 3.8. Use perf_counter where it exists
    # and fall back to clock() only on interpreters that lack it.
    timer = getattr(time, 'perf_counter', None) or time.clock
    start = timer()

    # Data lives under <parent of the current working directory>/dataset/sogou.
    now_directory = os.getcwd()
    root_directory = os.path.dirname(now_directory)
    read_directory = os.path.join(root_directory, u'dataset', u'sogou', u'text_category')
    write_directory = os.path.join(root_directory, u'dataset', u'sogou')

    # makedirs also creates a missing 'dataset' parent; os.mkdir would fail.
    if not os.path.exists(write_directory):
        os.makedirs(write_directory)

    random_mixture(read_directory, write_directory)

    print('Total time %f seconds' % (timer() - start))
    